Variational Auto-Encoders

Overview

A VAE has three key ideas worth understanding carefully: the encoder outputs a distribution (not a single point), a clever sampling trick makes that distribution differentiable, and the training objective balances reconstruction quality against keeping the latent space well-structured.

A Variational Autoencoder learns to compress data into a smooth, structured latent space and to decode from it. The key insight is that the encoder outputs not a single point but a probability distribution (a mean \(\mu\) and variance \(\sigma^2\)). This forces the latent space to be continuous, so nearby latent points decode to similar outputs.

What separates this from a regular autoencoder is the encoding step: instead of mapping \(x\) to a single point \(z\), the encoder outputs a distribution \((\mu, \sigma^2)\). Training then pulls that distribution toward the prior \(\mathcal{N}(0, \mathbf{I})\) via the KL term, which forces the latent space to be smooth and continuous. That is what makes generation possible: you can sample any point from the prior and decode something meaningful.

The reparameterization trick (\(z = \mu + \sigma \cdot \varepsilon\)) is the clever bit that makes the whole thing trainable with backprop: it moves the stochasticity into \(\varepsilon\), a fixed-distribution noise input that requires no gradient, so gradients flow cleanly through \(\mu\) and \(\sigma\).
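A minimal NumPy sketch of the trick, using hypothetical encoder outputs `mu` and `log_var` (frameworks typically parameterize the log-variance rather than \(\sigma\) itself for numerical stability):

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I).

    All randomness lives in eps; mu and log_var enter as ordinary
    deterministic inputs, so gradients can flow through them.
    """
    eps = rng.standard_normal(mu.shape)
    sigma = np.exp(0.5 * log_var)  # log-variance -> standard deviation
    return mu + sigma * eps

rng = np.random.default_rng(0)
# Hypothetical encoder outputs: batch of 4, latent dimension 2
mu = np.full((4, 2), 1.5)
log_var = np.zeros((4, 2))  # i.e. sigma = 1
z = reparameterize(mu, log_var, rng)
```

Averaged over many samples, `z` has mean `mu` and standard deviation `sigma`, exactly as the formula says.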

Mathematical Formulation

Encoder — approximate posterior:

\[q_\phi(z \mid x) = \mathcal{N}(z;\, \mu_\phi(x),\, \sigma^2_\phi(x) \mathbf{I})\]

Reparameterization trick:

\[z = \mu_\phi(x) + \sigma_\phi(x) \odot \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, \mathbf{I})\]

Decoder — likelihood:

\[p_\theta(x \mid z) = \mathcal{N}(x;\, \mu_\theta(z),\, \sigma^2 \mathbf{I}) \quad \text{or} \quad \prod_i \text{Bern}(x_i;\, f_\theta(z)_i)\]

Prior:

\[p(z) = \mathcal{N}(0, \mathbf{I})\]
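The prior is what enables generation: draw \(z \sim \mathcal{N}(0, \mathbf{I})\) and push it through the decoder. A toy sketch with a linear Bernoulli decoder, where the weights `W`, `b` are hypothetical stand-ins for a trained neural network:

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, data_dim = 2, 8

# Hypothetical "trained" decoder parameters (a real VAE uses a neural net)
W = rng.standard_normal((latent_dim, data_dim))
b = rng.standard_normal(data_dim)

def decode(z):
    """Map latent z to per-pixel Bernoulli means in (0, 1) via a sigmoid."""
    return 1.0 / (1.0 + np.exp(-(z @ W + b)))

# Generation: sample from the prior N(0, I), then decode
z = rng.standard_normal((5, latent_dim))
x_mean = decode(z)
```

Each row of `x_mean` is the mean of the Bernoulli likelihood \(p_\theta(x \mid z)\) for one prior sample.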

ELBO objective (maximized during training):

\[\mathcal{L}(\theta, \phi;\, x) = \mathbb{E}_{q_\phi(z \mid x)}\bigl[\log p_\theta(x \mid z)\bigr] - D_{\mathrm{KL}}\bigl(q_\phi(z \mid x) \;\|\; p(z)\bigr)\]
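In practice one minimizes the negative ELBO. A sketch for a Bernoulli decoder, where the reconstruction term is binary cross-entropy and the KL term uses the Gaussian closed form:

```python
import numpy as np

def negative_elbo(x, x_recon, mu, log_var):
    """Negative ELBO, averaged over the batch.

    x           : binary data, shape (batch, data_dim)
    x_recon     : decoder Bernoulli means in (0, 1), same shape as x
    mu, log_var : encoder outputs, shape (batch, latent_dim)
    """
    eps = 1e-7  # numerical floor to keep log() finite
    recon = np.sum(
        x * np.log(x_recon + eps) + (1 - x) * np.log(1 - x_recon + eps),
        axis=1,
    )
    kl = -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var), axis=1)
    return np.mean(-recon + kl)
```

With a perfect reconstruction and \(q = p(z)\) (i.e. \(\mu = 0, \sigma = 1\)) the loss is zero; degrading either term raises it.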

KL term in closed form (for Gaussians):

\[D_{\mathrm{KL}}\bigl(\mathcal{N}(\mu, \sigma^2) \;\|\; \mathcal{N}(0,1)\bigr) = -\frac{1}{2}\sum_{j=1}^{J}\left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)\]
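This expression is easy to verify numerically: the closed form should agree with a Monte Carlo estimate of \(\mathbb{E}_{q}[\log q(z \mid x) - \log p(z)]\). A sanity check with scalar \(\mu\) and \(\sigma\) for simplicity:

```python
import numpy as np

def kl_closed_form(mu, log_var):
    """D_KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions.

    Implements -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2).
    """
    return -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var))

# Monte Carlo estimate of E_q[log q(z) - log p(z)] for comparison
rng = np.random.default_rng(0)
mu, log_var = 0.7, np.log(0.5)
sigma = np.exp(0.5 * log_var)
z = mu + sigma * rng.standard_normal(200_000)
log_q = -0.5 * (np.log(2 * np.pi) + log_var + (z - mu) ** 2 / np.exp(log_var))
log_p = -0.5 * (np.log(2 * np.pi) + z**2)
mc_kl = np.mean(log_q - log_p)
```

The two values match to within sampling noise, and the KL is exactly zero when \(\mu = 0\) and \(\sigma = 1\).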

β-VAE generalization, where \(\beta > 1\) weights the KL term more heavily, trading reconstruction fidelity for a more disentangled latent space (\(\beta = 1\) recovers the standard VAE):

\[\mathcal{L}_\beta = \mathbb{E}_{q_\phi}\bigl[\log p_\theta(x \mid z)\bigr] - \beta \cdot D_{\mathrm{KL}}\bigl(q_\phi(z \mid x) \;\|\; p(z)\bigr)\]